import argparse
import os

import matplotlib.pyplot as plt
import numpy as np

# Set up argument parser
parser = argparse.ArgumentParser(description='Plot total variation distance histograms')
parser.add_argument('--t', type=int, default=8, help='Number of hash tables')
parser.add_argument('--f', type=int, default=4, help='Number of hash functions per table')
parser.add_argument('--base_path', type=str, default='~/results/sae-softmax/gemma-2-2b/', help='Base path for data files')

args = parser.parse_args()

# Load data
tv_distances_sae = np.load(f'{args.base_path}tv_distances_sae.npy')
tv_distances_lsh = np.load(f'{args.base_path}tv_distances_lsh_t{args.t}_f{args.f}.npy')
tv_distances_lsh_dual = np.load(f'{args.base_path}tv_distances_lsh_dual_t{args.t}_f{args.f}.npy')

# Create histogram
plt.figure(figsize=(10, 6))
plt.hist(tv_distances_sae, bins=50, alpha=0.5, label='SAE')
plt.hist(tv_distances_lsh, bins=50, alpha=0.5, label='LSH')
plt.hist(tv_distances_lsh_dual, bins=50, alpha=0.5, label='LSH Dual')

plt.xlabel('Total Variation Distance')
plt.ylabel('Frequency')
plt.title('Histogram of Total Variation Distances')
plt.legend()

# Save the plot
plt.savefig(f'tv_distance_histograms_t{args.t}_f{args.f}.png', dpi=300, bbox_inches='tight')
plt.close()

# Print summary statistics
print("\nSummary Statistics:")
print("-" * 50)
print("\nSAE:")
print(f"Mean: {np.mean(tv_distances_sae):.4f}")
print(f"Std:  {np.std(tv_distances_sae):.4f}")
print(f"Min:  {np.min(tv_distances_sae):.4f}")
print(f"Max:  {np.max(tv_distances_sae):.4f}")

print("\nLSH:")
print(f"Mean: {np.mean(tv_distances_lsh):.4f}")
print(f"Std:  {np.std(tv_distances_lsh):.4f}")
print(f"Min:  {np.min(tv_distances_lsh):.4f}")
print(f"Max:  {np.max(tv_distances_lsh):.4f}")

print("\nLSH Dual:")
print(f"Mean: {np.mean(tv_distances_lsh_dual):.4f}")
print(f"Std:  {np.std(tv_distances_lsh_dual):.4f}")
print(f"Min:  {np.min(tv_distances_lsh_dual):.4f}")
print(f"Max:  {np.max(tv_distances_lsh_dual):.4f}")
